{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2025, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}
#include <cooperative_groups.h>
#include <cooperative_groups/details/partitioning.h>
#include <cooperative_groups/memcpy_async.h>
#include <cooperative_groups/reduce.h>
#include <cuda_runtime.h>

#include "kernel_{{kernel.name}}.h"
#include "memops.cuh"

namespace cg = cooperative_groups;

namespace caspar {

/**
 * Device kernel generated for `{{kernel.name}}`.
 *
 * Each thread computes a flat 1D index from blockIdx.x, blockDim.x and
 * threadIdx.x, and only threads whose index is below `problem_size` execute
 * the generated computation.
 *
 * The kernel signature is assembled from every accessor's `kernel_sig`
 * entries (type/name pairs), followed by `problem_size`.
 *
 * `__launch_bounds__(1024, 1)` tells the compiler the kernel is launched with
 * at most 1024 threads per block and asks for at least one resident block per
 * multiprocessor.
 *
 * NOTE(review): `global_thread_idx` is an `int` while `problem_size` is a
 * `size_t`, so the guard below is a mixed signed/unsigned comparison; indices
 * that do not fit in an `int` are presumably never expected -- confirm the
 * supported problem-size range with the callers.
 */
__global__ void
__launch_bounds__(1024, 1)
{{kernel.name}}_kernel(
    {% for accessor in kernel.accessors %}
    {% for name, typ in accessor.kernel_sig.items() %}
      {{typ}} {{name}},
    {% endfor %}
    {% endfor %}
    size_t problem_size) {
  // Flat 1D index of this thread across the whole grid.
  const int global_thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // Statically sized shared scratch buffer; only emitted when the generator
  // reports a non-zero shared-memory requirement for this kernel.
  {% if kernel.shared_size_req > 0 %}
  __shared__ float inout_shared[{{kernel.shared_size_req}}];
  {% endif %}
  // Per-accessor preparation code. It is emitted before the bounds guard, so
  // it runs on every launched thread, including those past problem_size.
  {% for accessor in kernel.accessors %}
  {% filter indent(width=4) %}
    {{accessor.prep_lines}}
  {% endfilter %}
  {% endfor %}

  // Scalar float temporaries used by the generated code; declared
  // uninitialized, so the generated lines must write each one before reading.
  float {{','.join(kernel.registers)}};
  // Bounds guard: the grid is rounded up to whole blocks by the launcher, so
  // trailing threads of the last block skip the generated computation.
  if (global_thread_idx < problem_size){
  {% for code_line in kernel.code_lines %}
  {% filter indent(width=4) %}
    {{code_line}}
  {% endfilter %}
  {% endfor %}
  };  // the ';' after the closing brace is just an empty statement
}


/**
 * Host-side launcher for `{{kernel.name}}_kernel`.
 *
 * Takes the same arguments as the kernel (every accessor's `kernel_sig`
 * entries followed by `problem_size`) and forwards them unchanged.
 *
 * When `problem_size` is zero nothing is launched. Otherwise the kernel is
 * launched with 1024 threads per block and ceil(problem_size / 1024) blocks,
 * with no dynamic shared memory and no explicit stream argument. This
 * function neither synchronizes with the device nor queries the launch
 * status afterwards.
 */
void {{kernel.name}} (
  {% for accessor in kernel.accessors %}
  {% for name, typ in accessor.kernel_sig.items() %}
  {{typ}} {{name}},
  {% endfor %}
  {% endfor %}
  size_t problem_size) {
  // Must match the block size promised by the kernel's __launch_bounds__.
  constexpr unsigned int kLaunchThreadsPerBlock = 1024;

  // An empty problem would yield a zero-block grid, so skip the launch.
  if (problem_size > 0) {
    // Ceiling division: the last block may be only partially populated; the
    // kernel's own bounds guard masks off the surplus threads.
    const int launch_block_count = static_cast<int>(
        (problem_size + kLaunchThreadsPerBlock - 1) / kLaunchThreadsPerBlock);

    {{kernel.name}}_kernel<<<launch_block_count, kLaunchThreadsPerBlock>>>(
        {% for accessor in kernel.accessors %}
        {% for name in accessor.kernel_sig %}
        {{name}},
        {% endfor %}
        {% endfor %}
        problem_size
    );
  }
}

}  // namespace caspar
